#!/usr/bin/env python3
import os
import sys

import numpy as np
import pandas as pd
from scipy import stats

CURDIR = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, os.path.join(CURDIR, "helpers"))

from pure_http_client import ClickHouseClient


def test_and_check(name, a, b, t_stat, p_value, precision=1e-2):
    """Run a two-sample test in ClickHouse and compare it with reference values.

    Both samples are loaded into a scratch ``ks_test`` table, the aggregate
    given by ``name`` is evaluated over it, and the two elements of the
    returned tuple are compared with the expected statistic and p-value.

    Args:
        name: Text spliced into the SELECT as the aggregate function call,
            e.g. ``kolmogorovSmirnovTest('less', 'exact')``.
        a: Values of the first sample (stored with marker ``0``).
        b: Values of the second sample (stored with marker ``1``).
        t_stat: Expected test statistic.
        p_value: Expected p-value.
        precision: Absolute difference below which values count as matching.

    Raises:
        AssertionError: If the statistic or the p-value differs from the
            expected value by ``precision`` or more.
    """
    drop_table = "DROP TABLE IF EXISTS ks_test;"

    client = ClickHouseClient()
    # Start from a clean slate in case an earlier run left the table behind.
    client.query(drop_table)
    client.query("CREATE TABLE ks_test (left Float64, right UInt8) ENGINE = Memory;")

    # One INSERT per sample; the second column tells the samples apart.
    for sample, marker in ((a, 0), (b, 1)):
        rows = ", ".join([f"({value},{marker})" for value in sample])
        client.query(f"INSERT INTO ks_test VALUES {rows};")

    select_sql = (
        f"SELECT roundBankers({name}(left, right).1, 16) as t_stat, "
        f"roundBankers({name}(left, right).2, 16) as p_value "
        "FROM ks_test FORMAT TabSeparatedWithNames;"
    )
    result = client.query_return_df(select_sql)
    real_t_stat = result["t_stat"][0]
    real_p_value = result["p_value"][0]

    t_stat_error = abs(real_t_stat - np.float64(t_stat))
    assert (
        t_stat_error < precision
    ), f"clickhouse_t_stat {real_t_stat}, scipy_t_stat {t_stat}"

    p_value_error = abs(real_p_value - np.float64(p_value))
    assert (
        p_value_error < precision
    ), f"clickhouse_p_value {real_p_value}, scipy_p_value {p_value}"

    client.query(drop_table)


def test_ks_all_alternatives(rvs1, rvs2):
    """Compare ClickHouse with SciPy for several alternative/method settings.

    For each setting the expected statistic and p-value come from
    ``scipy.stats.ks_2samp`` and are checked against the matching
    ``kolmogorovSmirnovTest`` call through ``test_and_check``.

    Args:
        rvs1: First sample.
        rvs2: Second sample.
    """

    def compare(clickhouse_call, **scipy_kwargs):
        # Expected values from SciPy, then the same setting in ClickHouse.
        statistic, pvalue = stats.ks_2samp(rvs1, rvs2, **scipy_kwargs)
        test_and_check(clickhouse_call, rvs1, rvs2, statistic, pvalue)

    # Defaults on both sides.
    compare("kolmogorovSmirnovTest")
    compare("kolmogorovSmirnovTest('two-sided')", alternative="two-sided")
    compare(
        "kolmogorovSmirnovTest('greater', 'auto')",
        alternative="greater",
        method="auto",
    )
    compare(
        "kolmogorovSmirnovTest('less', 'exact')",
        alternative="less",
        method="exact",
    )

    # The 'asymp' method is only compared when a sample exceeds 10000 values.
    largest_sample = max(len(rvs1), len(rvs2))
    if largest_sample <= 10000:
        return

    compare(
        "kolmogorovSmirnovTest('two-sided', 'asymp')",
        alternative="two-sided",
        method="asymp",
    )
    compare(
        "kolmogorovSmirnovTest('greater', 'asymp')",
        alternative="greater",
        method="asymp",
    )


def test_kolmogorov_smirnov():
    """Check the Kolmogorov-Smirnov test on several pairs of normal samples.

    Each scenario draws two random samples from normal distributions with the
    listed parameters, rounds them to two decimals and passes the pair to
    ``test_ks_all_alternatives``.
    """
    # Each entry: (loc, scale, size) of the first sample, then of the second.
    scenarios = (
        ((1, 5, 100), (1.5, 5, 200)),
        ((13, 1, 100), (1.52, 9, 100)),
        ((1, 5, 100), (11.5, 50, 1000)),
        ((1, 5, 11000), (3.5, 5.5, 11000)),
    )

    for first_params, second_params in scenarios:
        # Draw the first sample before the second, keeping the draw order.
        first, second = (
            np.round(stats.norm.rvs(loc=loc, scale=scale, size=size), 2)
            for loc, scale, size in (first_params, second_params)
        )
        test_ks_all_alternatives(first, second)


# Script entry point: run every comparison and print "Ok." only if none of
# them raised.
if __name__ == "__main__":
    test_kolmogorov_smirnov()
    print("Ok.")
